/* Copyright (C) 2015-2018 RealVNC Ltd.  All Rights Reserved.
 */

#include <vnccommon/StringUtils.h>

#include <cstdio>
#include <cstdlib>
#include <limits>
#include <stdexcept>
#include <stdarg.h>
#include <limits.h>

#ifndef WINCE
#include <errno.h>
#endif

#ifdef WINCE
#define vsnprintf _vsnprintf
#endif

#if defined(WIN32) || defined(WINCE)
#define va_copy(dest, src) (dest = src)
#endif

using namespace vnccommon;


std::vector<std::string> StringUtils::split(
        const std::string& str,
        const char delimiter)
{
    return StringUtils::split(str, delimiter, Optional<size_t>());
}

// Splits str at each occurrence of delimiter.
//
// - An empty str yields an empty vector (not a vector holding one "").
// - Adjacent / leading / trailing delimiters produce empty elements.
// - If maxElemToSplit is set, at most that many elements are returned; the
//   last element holds the unsplit remainder of the string (delimiters
//   included).
//
// Throws std::runtime_error if maxElemToSplit is set to 0.
std::vector<std::string> StringUtils::split(
        const std::string& str,
        const char delimiter,
        const vnccommon::Optional<size_t> maxElemToSplit)
{
    std::vector<std::string> pieces;

    if (str.empty())
    {
        return pieces;
    }

    const bool limited = maxElemToSplit.hasValue();

    if (limited && maxElemToSplit.value() == 0)
    {
        throw std::runtime_error("Invalid value 0 for maxElemToSplit");
    }

    if (limited && maxElemToSplit.value() == 1)
    {
        // A single element: the whole string, unsplit.
        pieces.push_back(str);
        return pieces;
    }

    size_t pieceStart = 0;
    for (;;)
    {
        const size_t delimiterPos = str.find(delimiter, pieceStart);

        if (delimiterPos == std::string::npos)
        {
            // No delimiters left: everything remaining is the final piece.
            pieces.push_back(str.substr(pieceStart));
            break;
        }

        pieces.push_back(str.substr(pieceStart, delimiterPos - pieceStart));
        pieceStart = delimiterPos + 1;

        if (limited && pieces.size() >= maxElemToSplit.value() - 1)
        {
            // Split limit reached: the rest of the string, unsplit, becomes
            // the last element.
            pieces.push_back(str.substr(pieceStart));
            break;
        }
    }

    return pieces;
}

// Splits str into consecutive pieces of chunkSize characters. The last piece
// holds whatever is left over and may be shorter. An empty str yields an
// empty vector.
//
// Throws std::runtime_error if chunkSize is 0 (previously a division by zero).
std::vector<std::string> StringUtils::chunks(
        const std::string& str,
        const size_t chunkSize)
{
    if (chunkSize == 0)
    {
        throw std::runtime_error("Invalid value 0 for chunkSize");
    }

    std::vector<std::string> result;

    // Round up without the (size + chunkSize - 1) / chunkSize idiom. That idiom
    // overflows size_t when chunkSize is very large (e.g. SIZE_MAX) and would
    // then report zero chunks for a non-empty string.
    const size_t chunkCount = str.size() / chunkSize
            + ((str.size() % chunkSize) != 0 ? 1 : 0);
    result.reserve(chunkCount);

    // i * chunkSize < str.size() for every i < chunkCount, so this cannot overflow.
    for (size_t i = 0; i < chunkCount; ++i)
    {
        result.push_back(str.substr(i * chunkSize, chunkSize));
    }

    return result;
}


// Replaces occurrences of str2 in str1 (in place) with str3, scanning left to
// right.
//
// - If n is set, at most n replacements are made; n == 0 is a no-op.
// - An empty str2 is a no-op. An empty pattern matches at every position, so
//   replacing it would never terminate.
// - Text inserted from str3 is never rescanned, so replacing e.g. "a" with
//   "aa" terminates.
void StringUtils::replace(
        std::string& str1,
        const std::string& str2,
        const std::string& str3,
        const vnccommon::Optional<size_t> n)
{
    if (str2.empty() || (n.hasValue() && n.value() == 0))
    {
        return;
    }

    size_t start = 0;
    size_t replacements = 0;

    while ((start = str1.find(str2, start)) != std::string::npos)
    {
        str1.replace(start, str2.length(), str3);

        // Resume after the inserted text. Otherwise, when str2 occurs inside
        // str3, the replacement would be matched again and loop forever.
        start += str3.length();

        ++replacements;
        if (n.hasValue() && replacements >= n.value())
        {
            break;
        }
    }
}

// Returns true when str begins with prefix. Every string starts with "".
bool StringUtils::startsWith(
        const std::string& str,
        const std::string& prefix)
{
    const std::string::size_type prefixLength = prefix.size();

    if (str.size() < prefixLength)
    {
        // A prefix longer than the string can never match.
        return false;
    }

    return str.compare(0, prefixLength, prefix) == 0;
}

// Returns true when str ends with suffix. Every string ends with "".
bool StringUtils::endsWith(
        const std::string& str,
        const std::string& suffix)
{
    const std::string::size_type strLength = str.size();
    const std::string::size_type suffixLength = suffix.size();

    // The length check must come first: it keeps the subtraction below from
    // underflowing.
    return strLength >= suffixLength
            && str.compare(strLength - suffixLength, suffixLength, suffix) == 0;
}

// Matches str against a glob-style pattern in which '*' matches any run of
// characters (including none). No other wildcard characters are recognized.
//
// The pattern is treated as literal segments separated by '*':
// - the first segment must be a prefix of str;
// - every middle segment must occur in order after what was matched so far;
// - the last segment must be a suffix of str that does not overlap anything
//   already consumed.
bool StringUtils::wildcardMatch(
        const std::string& str,
        const std::string& pattern)
{
    const size_t firstStar = pattern.find('*');

    if (firstStar == std::string::npos)
    {
        // Without any '*' the pattern must match literally.
        return str == pattern;
    }

    // Literal text before the first '*' must be a prefix of str.
    if (str.compare(0, firstStar, pattern, 0, firstStar) != 0)
    {
        return false;
    }

    // Taking the leftmost occurrence of each middle segment is safe: it leaves
    // the most room for the segments that follow.
    size_t cursor = firstStar;
    size_t segmentBegin = firstStar + 1;
    for (size_t nextStar = pattern.find('*', segmentBegin);
            nextStar != std::string::npos;
            nextStar = pattern.find('*', segmentBegin))
    {
        const size_t segmentLength = nextStar - segmentBegin;
        const size_t hit = str.find(
                pattern.c_str() + segmentBegin, cursor, segmentLength);

        if (hit == std::string::npos)
        {
            return false;
        }

        cursor = hit + segmentLength;
        segmentBegin = nextStar + 1;
    }

    // Literal text after the last '*' must fit in what is left of str...
    const size_t tailLength = pattern.size() - segmentBegin;
    const size_t remaining = str.size() - cursor;
    if (remaining < tailLength)
    {
        return false;
    }

    // ...and sit exactly at its end.
    return str.compare(
            str.size() - tailLength, tailLength,
            pattern, segmentBegin, tailLength) == 0;
}

// printf-style formatting into a std::string. Collects the variadic arguments
// and hands them to vformat(), which does the buffer management and
// truncation handling.
std::string StringUtils::format(const char *const format, ...)
{
    va_list args;
    va_start(args, format);
    const std::string formatted = vformat(format, args);
    va_end(args);

    return formatted;
}

// vprintf-style formatting into a std::string.
//
// Returns "" when format is NULL. The buffer starts at 256 bytes and doubles
// until the output fits, but never grows beyond MAX_STRING_SIZE. If the output
// still does not fit at that size, it is truncated. When the truncation point
// falls in a UTF-8 multibyte sequence, that whole trailing sequence is dropped.
//
// Each attempt formats from a va_copy of ap, so the arguments can be re-read on
// every retry.
//
// NOTE(review): assumes MAX_STRING_SIZE >= 2 (the truncation path indexes
// size - 2) — confirm against the header.
std::string StringUtils::vformat(const char *const format, va_list ap)
{
    if (!format) return "";

    // Growth is clamped to this cap so the truncation branch below is always
    // reached. An equality test against the cap only works when the cap is
    // exactly 256 * 2^k; any other value would double the buffer without bound.
    const size_t maxSize = static_cast<size_t>(MAX_STRING_SIZE);

    size_t size = 256;
    if (size > maxSize)
    {
        size = maxSize;
    }

    // keep doubling the size of the buffer until the string fits (or the cap
    // is reached)
    while (true)
    {
        std::vector<char> buf;
        buf.resize(size);
        char* rawBuf = &buf[0];

        // on standard compliant platforms vsnprintf returns chars that would be written
        // on windows platforms it returns error when truncated and if string fits
        // without null char it still returns success
        va_list apCopy;
        va_copy(apCopy, ap);
        const int written = vsnprintf(rawBuf, size, format, apCopy);
        va_end(apCopy);

        if (written >= 0 && static_cast<size_t>(written) < size)
        {
            // Defensive termination for non-conforming implementations.
            rawBuf[size - 1] = '\0';
            return std::string(rawBuf);
        }

        if (size >= maxSize)
        {
            // to avoid truncating midway through UTF8 we remove the last character
            // UTF8 multibyte chars start with 11xxxxxx and continue with 10xxxxxx
            char* end = &rawBuf[size - 2];
            if (*end & 0x80)
            {
                while (end > &rawBuf[0])
                {
                    // keep moving pointer back until we hit a non continuation byte
                    if ((*end & 0xC0) != 0x80)
                    {
                        break;
                    }
                    --end;
                }
                *end = '\0';
            }
            else
            {
                rawBuf[size - 1] = '\0';
            }
            return std::string(rawBuf);
        }

        // Double, but never past the cap. The comparison with maxSize / 2 also
        // keeps size * 2 from overflowing.
        size = (size > maxSize / 2) ? maxSize : size * 2;
    }
    return "";
}

// Parses str as a long in the given base (as accepted by strtol).
//
// Returns an empty Optional in these cases:
// - nothing was parsed;
// - trailing characters remain after the number (leading whitespace is
//   accepted by strtol);
// - on non-WINCE builds, the value is out of range for long.
Optional<long> StringUtils::toLong(const std::string& str, int base)
{
    const char *const begin = str.c_str();
    const char *const expectedEnd = begin + str.size();
    char* parsedEnd = NULL;

#ifndef WINCE
    errno = 0;
#endif

    const long value = strtol(begin, &parsedEnd, base);

    const bool consumedEverything = parsedEnd != NULL
            && parsedEnd != begin
            && parsedEnd == expectedEnd;
    if (!consumedEverything)
    {
        return Optional<long>();
    }

#ifndef WINCE
    // strtol saturates to LONG_MIN / LONG_MAX and sets ERANGE on overflow.
    if ((value == LONG_MIN || value == LONG_MAX) && errno == ERANGE)
    {
        return Optional<long>();
    }
#endif

    return MakeOptional(value);
}

// Parses str as an unsigned long in the given base (as accepted by strtoul).
//
// Returns an empty Optional in these cases:
// - the number is negative;
// - nothing was parsed;
// - trailing characters remain after the number (leading whitespace is
//   accepted);
// - on non-WINCE builds, the value is out of range for unsigned long.
Optional<unsigned long> StringUtils::toUnsignedLong(const std::string& str, int base)
{
    // strtoul() accepts a leading '-' and returns the negated value converted
    // to unsigned long, so "-1" would otherwise be reported as ULONG_MAX.
    // Reject negative input explicitly. strtoul() skips leading whitespace
    // before the sign, so skip the same characters here.
    const std::string::size_type firstNonSpace =
            str.find_first_not_of(" \t\n\v\f\r");
    if (firstNonSpace != std::string::npos && str[firstNonSpace] == '-')
    {
        return Optional<unsigned long>();
    }

    char* status = NULL;
    const char *const strRaw = str.c_str();

#ifndef WINCE
    errno = 0;
#endif

    const unsigned long result = strtoul(strRaw, &status, base);

    // Reject if nothing was parsed or if unparsed characters remain.
    if (status == NULL || status == strRaw || status != strRaw + str.size())
    {
        return Optional<unsigned long>();
    }
#ifndef WINCE
    // strtoul saturates to ULONG_MAX and sets ERANGE on overflow.
    else if (result == ULONG_MAX && errno == ERANGE)
    {
        return Optional<unsigned long>();
    }
#endif
    else
    {
        return MakeOptional(result);
    }
}

// Parses str as a signed 32-bit integer in the given base. It delegates to
// toLong(), then range-checks the result against vnc_int32_t. Returns an empty
// Optional on any parse failure or when the value does not fit.
Optional<vnc_int32_t> StringUtils::toInt32(const std::string& str, int base)
{
    // Compile-time check that long can represent every vnc_int32_t.
    StaticAssert<(sizeof(long) >= sizeof(vnc_int32_t))>::staticAssert();

    const Optional<long> parsed = StringUtils::toLong(str, base);

    if (parsed.hasValue())
    {
        const long candidate = parsed.value();
        const bool fits = candidate >= std::numeric_limits<vnc_int32_t>::min()
                && candidate <= std::numeric_limits<vnc_int32_t>::max();
        if (fits)
        {
            return MakeOptional(static_cast<vnc_int32_t>(candidate));
        }
    }

    return Optional<vnc_int32_t>();
}

// Parses str as an unsigned 32-bit integer in the given base. It delegates to
// toUnsignedLong(), then range-checks the result against vnc_uint32_t. Returns
// an empty Optional on any parse failure or when the value exceeds the
// vnc_uint32_t maximum.
Optional<vnc_uint32_t> StringUtils::toUint32(const std::string& str, int base)
{
    // Compile-time check that unsigned long can represent every vnc_uint32_t.
    StaticAssert<(sizeof(unsigned long) >= sizeof(vnc_uint32_t))>::staticAssert();

    const Optional<unsigned long> ul = StringUtils::toUnsignedLong(str, base);

    if (!ul.hasValue())
    {
        return Optional<vnc_uint32_t>();
    }

    // Only the upper bound needs checking: an unsigned value is never below
    // numeric_limits<vnc_uint32_t>::min() (0). Testing it was always false and
    // only produced compiler warnings.
    if (ul.value() > std::numeric_limits<vnc_uint32_t>::max())
    {
        return Optional<vnc_uint32_t>();
    }

    return MakeOptional(static_cast<vnc_uint32_t>(ul.value()));
}

// Encodes dataLength bytes from data as uppercase hexadecimal, two characters
// per byte with no separators (e.g. {0x0A, 0xFF} -> "0AFF").
std::string StringUtils::binToHex(const vnc_uint8_t* data, size_t dataLength)
{
    static const char hexDigits[] = "0123456789ABCDEF";

    std::string encoded;
    encoded.reserve(dataLength * 2);

    for (size_t index = 0; index < dataLength; ++index)
    {
        const unsigned int byteValue = data[index];
        encoded += hexDigits[(byteValue >> 4) & 0x0F];
        encoded += hexDigits[byteValue & 0x0F];
    }

    return encoded;
}

// Converts a single hex digit ('0'-'9', 'a'-'f', 'A'-'F') to its value 0-15.
// Returns an empty Optional for any other character.
Optional<vnc_uint8_t> StringUtils::hexCharToByte(const char hex)
{
    int nibble = -1;

    if (hex >= '0' && hex <= '9')
    {
        nibble = hex - '0';
    }
    else if (hex >= 'a' && hex <= 'f')
    {
        nibble = (hex - 'a') + 10;
    }
    else if (hex >= 'A' && hex <= 'F')
    {
        nibble = (hex - 'A') + 10;
    }
    else
    {
        return Optional<vnc_uint8_t>();
    }

    return Optional<vnc_uint8_t>(static_cast<vnc_uint8_t>(nibble));
}

void StringUtils::hexToBin(
        std::vector<vnc_uint8_t>& output,
        const std::string& hexData)
{
    Optional<vnc_uint8_t> partialByte;

    for(size_t i = 0; i < hexData.size(); i++)
    {
        const Optional<vnc_uint8_t> thisChar = hexCharToByte(hexData[i]);

        if(thisChar.hasValue())
        {
            if(partialByte.hasValue())
            {
                output.push_back(
                        (partialByte.value() << 4) | thisChar.value());

                partialByte = Optional<vnc_uint8_t>();
            }
            else
            {
                partialByte = thisChar;
            }
        }
    }
}

std::vector<vnc_uint8_t> StringUtils::hexToBin(const std::string& hexData)
{
    std::vector<vnc_uint8_t> result;
    hexToBin(result, hexData);
    return result;
}

// Parses a "0x"/"0X"-prefixed hexadecimal string as a vnc_uint32_t. Returns
// an empty Optional if the prefix is missing, no digits follow it, or
// toUint32() rejects the value.
Optional<vnc_uint32_t> StringUtils::prefixedHexStringToUint32(const std::string& str)
{
    // Must be at least "0x" plus one more character.
    const bool hasHexPrefix = str.size() >= 3
            && str[0] == '0'
            && (str[1] == 'x' || str[1] == 'X');

    if (hasHexPrefix)
    {
        return toUint32(str, 16);
    }

    return Optional<vnc_uint32_t>();
}

// Parses a "0x"/"0X"-prefixed hexadecimal string as a vnc_int32_t. Returns
// an empty Optional if the prefix is missing, no digits follow it, or
// toInt32() rejects the value (e.g. above the vnc_int32_t maximum).
Optional<vnc_int32_t> StringUtils::prefixedHexStringToInt32(const std::string& str)
{
    // Must be at least "0x" plus one more character.
    const bool hasHexPrefix = str.size() >= 3
            && str[0] == '0'
            && (str[1] == 'x' || str[1] == 'X');

    if (hasHexPrefix)
    {
        return toInt32(str, 16);
    }

    return Optional<vnc_int32_t>();
}

// Returns, in order, every substring of input enclosed between startDelimiter
// and a following endDelimiter (delimiters excluded).
//
// - Matches do not overlap; searching resumes after each endDelimiter.
// - A startDelimiter with no matching endDelimiter ends the search.
// - If both delimiters are empty, the result is empty. Otherwise every search
//   would match at the current position without advancing, and the loop would
//   never terminate.
std::vector<std::string> StringUtils::extractDelimitedStrings(
        const std::string& input,
        const std::string& startDelimiter,
        const std::string& endDelimiter)
{
    std::vector<std::string> result;

    if (startDelimiter.empty() && endDelimiter.empty())
    {
        return result;
    }

    if (input.size() < (startDelimiter.size() + endDelimiter.size()))
    {
        return result;
    }

    size_t currentStartPos = 0;
    size_t startDelimiterStart;

    while ((startDelimiterStart = input.find(startDelimiter, currentStartPos))
            != std::string::npos)
    {
        const size_t substrStartPos = startDelimiterStart + startDelimiter.size();

        const size_t endDelimiterStart
                = input.find(endDelimiter, substrStartPos);

        if (endDelimiterStart == std::string::npos)
        {
            break;
        }

        result.push_back(input.substr(
                substrStartPos,
                endDelimiterStart - substrStartPos));

        // At least one delimiter is non-empty here, so this strictly advances.
        currentStartPos = endDelimiterStart + endDelimiter.size();
    }

    return result;
}

// Creates a builder with no accumulated text. There is no initializer list,
// so members are default-initialized.
StringUtils::StringBuilder::StringBuilder()
{
}

// Copy constructor. Only the text accumulated so far in o's stream is copied.
// Other stream state (e.g. formatting flags such as std::hex) and the cached
// mString buffer are not carried over.
StringUtils::StringBuilder::StringBuilder(const StringBuilder& o)
{
    // Inserted with << so later insertions append after the copied text.
    mStream << o.mStream.str();
}

// Implicit conversion to std::string; returns the same value as toString().
StringUtils::StringBuilder::operator std::string() const
{
    return toString();
}

// Implicit conversion to a C string.
//
// The accumulated text is snapshotted into mString and a pointer into that
// string is returned. This is why the operator is non-const. The pointer stays
// valid only until mString is next modified (e.g. by another call of this
// operator) or this builder is destroyed.
//
// Beware temporaries: a pointer taken from a temporary builder (such as the
// object returned by build()) dangles at the end of the full expression.
StringUtils::StringBuilder::operator const char*()
{
    mString = mStream.str();
    return mString.c_str();
}

// Returns a copy of the text accumulated in the builder so far. The builder
// is left unchanged.
std::string StringUtils::StringBuilder::toString() const
{
    return mStream.str();
}

// Factory returning a new, empty StringBuilder by value.
StringUtils::StringBuilder StringUtils::build()
{
    return StringBuilder();
}

// Wraps a nullable C string. NULL becomes an empty Optional; anything else is
// copied into a std::string.
Optional<std::string> StringUtils::toOptionalString(const char *const str)
{
    if (str != NULL)
    {
        return Optional<std::string>(std::string(str));
    }

    return Optional<std::string>();
}

// Inverse of toOptionalString(): returns the held string's c_str(), or NULL
// when value is empty.
//
// NOTE(review): the pointer refers to the string held by value, so it is only
// valid while value is alive and unmodified. This assumes Optional::value()
// returns a reference to the stored string rather than a copy; confirm in the
// Optional header.
const char* StringUtils::toNullableCharPtr(
        const vnccommon::Optional<std::string>& value)
{
    return value.hasValue() ? value.value().c_str() : NULL;
}


// Returns a copy of str with leading and trailing whitespace removed.
// Whitespace is whatever ::isspace() reports in the current C locale.
// An all-whitespace or empty input yields "".
std::string StringUtils::trim(const std::string& str)
{
    // ::isspace() is undefined for negative arguments other than EOF, and
    // plain char is signed on many platforms. Bytes >= 0x80 (e.g. UTF-8
    // sequences) must therefore be converted to unsigned char before they are
    // classified.
    std::string::const_iterator start = str.begin();
    while (start != str.end() && ::isspace(static_cast<unsigned char>(*start)))
    {
        ++start;
    }

    if (start == str.end())
    {
        return "";
    }

    // A non-space character exists (at *start), so this stops before rend().
    std::string::const_reverse_iterator end = str.rbegin();
    while (end != str.rend() && ::isspace(static_cast<unsigned char>(*end)))
    {
        ++end;
    }

    return std::string(start, end.base());
}
